DESI DR1 Stream Tutorial - Work In Progress!!¶

Welcome to the 2025 version of our stellar stream characterization tutorial notebook. This notebook walks you through using DESI Milky Way Survey data from the first full public release (DR1)!

Below, we're going to import the packages we'll be using.

#-----------------------#

#Look here for stuff to change throughout the notebook!

#-----------------------#
In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy as sp
import scipy.stats as stats
from astropy.io import fits
from astropy import table
import matplotlib
import matplotlib.patheffects as path_effects
import importlib
import stream_functions as stream_funcs
importlib.reload(stream_funcs)
import emcee
import corner
from astropy import units as u
from collections import OrderedDict
import time
from scipy import optimize, stats
import matplotlib.colors as mcolors
colors = mcolors.CSS4_COLORS
color_names = list(colors.keys())
import streamTutorial as st
import copy
importlib.reload(st);

Next, let's input our DESI MWS DR1 data.

In [2]:
importlib.reload(st)
# Add the path to the DESI data and STREAMFINDER data below
#-----------------------#
desi_path = '../mwsall-pix-iron.fits'
sf_path = './data/streamfinder_gaiadr3.fits'
#-----------------------#

Data = st.Data(desi_path, sf_path)

print('Now our desi data has been loaded under Data.desi_data')
Length of DESI Data before Cuts: 6372607
Length after NaN cut: 4075716
Adding empirical FEH calibration (can find uncalibrated data in column['FEH_uncalib])
Now our desi data has been loaded under Data.desi_data

Pick Stream¶

We'll pick a stream to work with by entering it in

st.stream(Data, streamName='SoI-I21', streamNo=42)

Initializing this stream object will add an attribute to the Data object, accessed either through

Data.confirmed_sf_and_desi or SoI.data.confirmed_sf_and_desi

It will also calculate the stream coordinates $\phi_1$ and $\phi_2$. This is achieved by rotating the right ascension and declination so that the stream runs along $\phi_2 \sim 0$, with its center at $\phi_1 \sim 0$.
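
To make the rotation concrete, here is a minimal sketch that builds a stream-aligned frame from two points on the stream. It is not how st.stream does it (the log output suggests the tutorial takes its frame from galstreams); the function names and the two end points are hypothetical and chosen only for illustration.

```python
import numpy as np

def radec_to_unit(ra_deg, dec_deg):
    """Convert RA/Dec in degrees to Cartesian unit vectors (last axis = xyz)."""
    ra, dec = np.radians(ra_deg), np.radians(dec_deg)
    return np.stack([np.cos(dec) * np.cos(ra),
                     np.cos(dec) * np.sin(ra),
                     np.sin(dec)], axis=-1)

def stream_frame_matrix(end_a, end_b):
    """Rotation matrix whose equator is the great circle through two (RA, Dec) points."""
    a = radec_to_unit(*end_a)
    b = radec_to_unit(*end_b)
    pole = np.cross(a, b)
    pole /= np.linalg.norm(pole)          # new z-axis: normal to the stream plane
    mid = a + b
    mid /= np.linalg.norm(mid)            # new x-axis: midpoint of the two ends
    y = np.cross(pole, mid)               # new y-axis completes the right-handed set
    return np.vstack([mid, y, pole])

def to_stream_frame(ra_deg, dec_deg, R):
    """Rotate RA/Dec (degrees) into (phi1, phi2) in degrees."""
    xyz = radec_to_unit(np.asarray(ra_deg, float), np.asarray(dec_deg, float)) @ R.T
    phi1 = np.degrees(np.arctan2(xyz[..., 1], xyz[..., 0]))
    phi2 = np.degrees(np.arcsin(np.clip(xyz[..., 2], -1.0, 1.0)))
    return phi1, phi2

# Hypothetical end points of a stream, in degrees
R = stream_frame_matrix((150.0, 20.0), (170.0, 30.0))
phi1, phi2 = to_stream_frame([150.0, 170.0], [20.0, 30.0], R)
```

In this sketch both end points land on $\phi_2 = 0$ and sit symmetrically about $\phi_1 = 0$, which is the property the tutorial's stream frame is described as having.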

In [4]:
#del SoI
In [33]:
importlib.reload(stream_funcs)

#-----------------------#
streamName= 'Gaia-6-I21'#'Fjorm-I21' #'Sylgr-I21'
streamNo= 59# 47 #42
#-----------------------#

importlib.reload(st)
SoI = st.stream(Data, streamName, streamNo)
Importing galstreams module...
Initializing galstreams library from master_log... 
8.032320598835131
Creating combined DataFrame of SF and DESI
No stars were cut - cut_confirmed_sf_and_desi is empty
Number of stars in SF: 145, Number of DESI and SF stars: 49
Saved merged DataFrame as self.data.confirmed_sf_and_desi_b
Stars only in SF3: 96

Let's take a look at the rotated stream below.

In [34]:
importlib.reload(st)
plt_soi = st.StreamPlotter(SoI)

plt_soi.on_sky(stream_frame=False)
plt_soi.on_sky(stream_frame=True)
/Users/nasserm/Documents/vscode/research/streamTut/DESI-DR1_streamTutorial/streamTutorial.py:652: UserWarning: You passed a edgecolor/edgecolors ('k') for an unfilled marker ('x').  Matplotlib is ignoring the edgecolor in favor of the facecolor.  This behavior may change in the future.
  **self.plot_params['sf_in_desi_notsel']

[Figure: on_sky() with stream_frame=False]
[Figure: on_sky() with stream_frame=True]

Trimming DESI Data¶

We have a lot of DESI data, so let's cut it down with the following steps:

  • Perform a distance cut to remove stars too nearby
  • Perform an RA/DEC cut to remove stars on different parts of the sky
  • Trim stars with kinematics nothing like the STREAMFINDER stars
  • Trim stars that are too metal rich
  • Trim stars that fall outside of the best-fit stellar population model

The first two steps can be done with a helper function stored in stream_functions.py:

stream_funcs.threeD_max_min_mask
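
The source of threeD_max_min_mask is not shown in this notebook, so the following is only a sketch of what a combined sky-box and distance cut could look like. The function name, the n_sigma parameter, and the assumption that the minimum distance is in kpc are all hypothetical, and RA wrap-around at 0/360 degrees is ignored.

```python
import numpy as np

def box_and_distance_mask(ra, dec, plx, plx_err, ref_ra, ref_dec,
                          min_dist_kpc, ra_pad, dec_pad, n_sigma=2.0):
    """Keep stars inside a padded RA/Dec box that are not confidently too nearby."""
    ra, dec = np.asarray(ra, float), np.asarray(dec, float)
    plx, plx_err = np.asarray(plx, float), np.asarray(plx_err, float)

    # Sky box: the extent of the reference (stream) stars, padded on each side
    in_ra = (ra > np.min(ref_ra) - ra_pad) & (ra < np.max(ref_ra) + ra_pad)
    in_dec = (dec > np.min(ref_dec) - dec_pad) & (dec < np.max(ref_dec) + dec_pad)

    # Parallax [mas] = 1 / distance [kpc]; drop a star only if its parallax
    # exceeds the limit even after subtracting n_sigma times its uncertainty
    max_plx = 1.0 / min_dist_kpc
    not_too_close = (plx - n_sigma * plx_err) < max_plx

    return in_ra & in_dec & not_too_close

# Three toy stars: one good, one off the box in RA, one clearly too nearby
demo_mask = box_and_distance_mask(
    ra=[100.0, 130.0, 101.0], dec=[10.0, 10.0, 11.0],
    plx=[0.1, 0.1, 2.0], plx_err=[0.05, 0.05, 0.05],
    ref_ra=[98.0, 104.0], ref_dec=[8.0, 12.0],
    min_dist_kpc=5.0, ra_pad=5.0, dec_pad=10.0)
```

Subtracting a multiple of the parallax error keeps distant stars with noisy parallaxes in the sample, which is usually what you want at this early, deliberately wide stage.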

In [73]:
importlib.reload(st)
selection_fine = st.Selection(SoI.data.desi_data)

# We want to get the ra and dec from STREAMFINDER stars for this function so we can cut around it


#-----------------------#
ra_cut = 5 #deg
dec_cut = 10 #deg
#-----------------------#

selection_fine.add_mask(name='3D',
        mask_func=lambda df: stream_funcs.threeD_max_min_mask(
        df['TARGET_RA'],        
        df['TARGET_DEC'],         
        df['PARALLAX'],       
        df['PARALLAX_ERROR'],   
        SoI.data.SoI_streamfinder['RAdeg'],         
        SoI.data.SoI_streamfinder['DEdeg'],        
        SoI.min_dist,              
        ra_cut,dec_cut) #<-----------------------# dec cut, ra cut (wide cuts for now)
)
Selection object created for DataFrame with 4075716 rows.
Mask added: '3D'

Once we've defined all our cuts, we can combine every mask into one with final_mask = selection_fine.get_final_mask()
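
The star counts printed by the later cells (each mask selects many stars, the combination far fewer) are consistent with masks being combined with a logical AND. A minimal sketch of that idea follows; MaskRegistry is a hypothetical stand-in, not the real st.Selection, and it uses a dict of arrays instead of a DataFrame.

```python
import numpy as np

class MaskRegistry:
    """Hypothetical stand-in for st.Selection: stores named boolean masks."""

    def __init__(self, data):
        self.data = data
        self.masks = {}

    def add_mask(self, name, mask_func):
        # Evaluate the mask once and store it under its name
        self.masks[name] = np.asarray(mask_func(self.data), dtype=bool)

    def get_masks(self, names):
        # A star survives only if every requested mask keeps it
        return np.logical_and.reduce([self.masks[n] for n in names])

    def get_final_mask(self):
        return self.get_masks(list(self.masks))

data = {'phi2': np.array([-6.0, -1.0, 2.0, 7.0]),
        'FEH': np.array([-1.5, 0.3, -2.0, -1.0])}
sel = MaskRegistry(data)
sel.add_mask('phi2', lambda d: np.abs(d['phi2']) < 5)
sel.add_mask('feh', lambda d: d['FEH'] < 0)
final = sel.get_final_mask()
```

Keeping masks by name makes it cheap to try different combinations, which is how get_masks([...]) is used in the rest of the notebook.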

In [74]:
importlib.reload(st)
# Let's take a look at our trimmed DESI data

final_mask = selection_fine.get_final_mask()

# Using the new method that replaces the 4-line pattern
trimmed_stream = SoI.mask_stream(final_mask)

plt_trim = st.StreamPlotter(trimmed_stream)
plt_trim.on_sky(showStream=True, background=True, stream_frame=False)
plt_trim.on_sky(showStream=True, background=True)
Combining masks...
...'3D' selected 69353 stars
Selection: 69353 / 4075716 stars.
Created cut_confirmed_sf_and_desi with 2 stars that were filtered out
Number of stars in SF: 145, Number of DESI and SF stars: 47
Saved merged DataFrame as self.data.confirmed_sf_and_desi_b
Stars only in SF3: 98
[Figure: on_sky() for the trimmed data with stream_frame=False]
[Figure: on_sky() for the trimmed data with the default frame]

Let's also look at our parallax cut:

In [75]:
plt_trim.plx_cut()

#del plt_trim, trimmed_desi, trimmed_stream # clear memory
[Figure: plx_cut() output]

Now that we've visualized the stream in the stream-frame, we can get a bit more specific with our on-sky cuts.

In [76]:
#-----------------------#
phi2_wiggle = 5 #[deg]
#-----------------------#

selection_fine.add_mask(name='phi2',
        mask_func=lambda df: (df['phi2'] < phi2_wiggle) & (df['phi2'] > -1*phi2_wiggle))
Mask added: 'phi2'
In [77]:
# del plt_trim
In [78]:
importlib.reload(st)
# Let's take a look at our trimmed DESI data
mask = selection_fine.get_masks(['3D', 'phi2'])

trimmed_stream = SoI.mask_stream(mask)

plt_trim = st.StreamPlotter(trimmed_stream)
plt_trim.on_sky(showStream=True, background=True, stream_frame=False)
plt_trim.on_sky(showStream=True, background=True)
...'3D' selected 69353 stars
...'phi2' selected 423352 stars
Selection for specified masks: 33744 / 4075716 stars.
Created cut_confirmed_sf_and_desi with 2 stars that were filtered out
Number of stars in SF: 145, Number of DESI and SF stars: 47
Saved merged DataFrame as self.data.confirmed_sf_and_desi_b
Stars only in SF3: 98
[Figure: on_sky() for the '3D' + 'phi2' selection with stream_frame=False]
[Figure: on_sky() for the '3D' + 'phi2' selection with the default frame]

Metallicity and Isochrone Cuts¶

Metallicity¶

Let's grab the stream's known metallicity from STREAMFINDER.

In [79]:
sf3_table = pd.read_csv('./data/sf3_only_table.csv')


sf_streamname = streamName.rsplit('-', 1)[0] #If this fails, manually enter the stream name from the first table.
metallicity = sf3_table[sf3_table['Stream'] == sf_streamname]['Metallicities'].values
print(f"{sf_streamname} Metallicity: {metallicity}")
print(f'Mass Fraction (Z) Guess: {0.0181 * 10 ** metallicity}')
Gaia-6 Metallicity: [-1.53]
Mass Fraction (Z) Guess: [0.00053417]
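
The mass-fraction guess printed above comes from scaling a solar value by the metallicity, $Z = Z_\odot \times 10^{\mathrm{[Fe/H]}}$, with $Z_\odot = 0.0181$ taken from the cell itself. As a quick worked check (the helper name is hypothetical):

```python
Z_SUN = 0.0181  # solar mass fraction, as used in the cell above

def feh_to_z(feh, z_sun=Z_SUN):
    """Approximate metal mass fraction Z for a given [Fe/H] in dex."""
    return z_sun * 10.0 ** feh

z_guess = feh_to_z(-1.53)  # Gaia-6 value from the SF3 table
```

For [Fe/H] = -1.53 this gives about 5.3e-4, matching the printed guess.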
In [80]:
plt_trim.plot_params['background']['alpha']=0.05
plot = plt_trim.feh_plot(showStream=True, background=True)
ax = plot[1]
ax.axhline(metallicity, color='g', linestyle='solid', label='Metallicity from SF3', alpha=0.5)
ax.legend()
Out[80]:
<matplotlib.legend.Legend at 0x30b4a3dd0>
[Figure: feh_plot() output with the SF3 metallicity line]

Now let's trim stars that are too metal-rich!

In [81]:
#-----------------------#
feh_cut = 0 # [dex]
#-----------------------#

selection_fine.add_mask(name='feh',
        mask_func=lambda df: (df['FEH'] < feh_cut))
Mask added: 'feh'
In [82]:
#del trimmed_stream, plt_trim # clear memory
importlib.reload(st)
mask = selection_fine.get_masks(['3D', 'phi2', 'feh'])

trimmed_stream = SoI.mask_stream(mask)


plt_trim = st.StreamPlotter(trimmed_stream)
plot = plt_trim.feh_plot(showStream=True, background=True)
ax = plot[1]
ax.axhline(metallicity, color='g', linestyle='solid', label='Metallicity from SF3', alpha=0.5)
ax.axhline(feh_cut, color='red', linestyle='--', label='Cut')
ax.legend(loc='upper right')
...'3D' selected 69353 stars
...'phi2' selected 423352 stars
...'feh' selected 3085218 stars
Selection for specified masks: 30708 / 4075716 stars.
Created cut_confirmed_sf_and_desi with 3 stars that were filtered out
Number of stars in SF: 145, Number of DESI and SF stars: 46
Saved merged DataFrame as self.data.confirmed_sf_and_desi_b
Stars only in SF3: 99
Out[82]:
<matplotlib.legend.Legend at 0x16cebdb10>
[Figure: feh_plot() output with the SF3 metallicity line and the [Fe/H] cut]

Isochrone¶

Let's fit an isochrone to the STREAMFINDER stars and trim the DESI sample in one fell swoop.

In [83]:
#-----------------------#
colour_wiggle = 0.18 # [mag]
age = 13.5 # [Gyr] -> of form 10.0, 10.1, 10.2, etc
#-----------------------#
SoI.isochrone(metallicity, age=age) 

selection_fine.add_mask(name='iso',
        mask_func=lambda df: (stream_funcs.betw(SoI.data.desi_colour_idx, SoI.isochrone_fit(SoI.data.desi_abs_mag), colour_wiggle)))
Mass Fraction (Z): [0.00053417]

using ./data/dotter/iso_a13.5_z0.00057.dat
Using distance gradient
Mask added: 'iso'
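
The isochrone mask keeps stars whose colour lies within colour_wiggle of the isochrone colour at their absolute magnitude. The sketch below shows the idea with a made-up five-point isochrone; it is not a real Dotter track, and the helper names are hypothetical (stream_funcs.betw and SoI.isochrone_fit are not shown in this notebook).

```python
import numpy as np

# Toy isochrone (NOT a real Dotter track): colour as a function of absolute magnitude.
# Magnitudes must be increasing for np.interp.
iso_abs_mag = np.array([-2.0, 0.0, 2.0, 4.0, 6.0])
iso_colour = np.array([1.00, 0.75, 0.55, 0.45, 0.60])

def apparent_to_absolute(app_mag, dist_kpc):
    """Remove the distance modulus: M = m - 5 log10(d / 10 pc)."""
    return app_mag - 5.0 * np.log10(dist_kpc * 1000.0 / 10.0)

def near_isochrone(colour, abs_mag, wiggle):
    """True where a star's colour is within `wiggle` mag of the isochrone colour."""
    expected = np.interp(abs_mag, iso_abs_mag, iso_colour)
    return np.abs(np.asarray(colour, float) - expected) < wiggle

abs_mag_demo = apparent_to_absolute(15.0, 10.0)   # a 15th-magnitude star at 10 kpc
keep_demo = near_isochrone([0.76, 1.20], [0.0, 0.0], wiggle=0.18)
```

Because the cut is applied in absolute magnitude, the result depends on the adopted distance; the "Using distance gradient" message above suggests the tutorial lets that distance vary along the stream.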
In [84]:
importlib.reload(st)
mask = selection_fine.get_masks(['3D', 'phi2', 'feh', 'iso'])

trimmed_stream = SoI.mask_stream(mask)
trimmed_stream.isochrone(metallicity, age=age) 

plt_trim = st.StreamPlotter(trimmed_stream)
plt_trim.iso_plot(wiggle=colour_wiggle, background=True, showStream=True)
...'3D' selected 69353 stars
...'phi2' selected 423352 stars
...'feh' selected 3085218 stars
...'iso' selected 1228176 stars
Selection for specified masks: 14931 / 4075716 stars.
Created cut_confirmed_sf_and_desi with 20 stars that were filtered out
Number of stars in SF: 145, Number of DESI and SF stars: 29
Saved merged DataFrame as self.data.confirmed_sf_and_desi_b
Stars only in SF3: 116
Mass Fraction (Z): [0.00053417]

using ./data/dotter/iso_a13.5_z0.00057.dat
Using distance gradient
[Figure: iso_plot() output]

Kinematic Cuts¶

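We start with DESI's radial velocities, expressed in the Galactic standard of rest (VGSR), and then cut on the proper motions.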

In [85]:
importlib.reload(st)
plt_trim.plot_params['background']['alpha']=0.08
plt_trim.kin_plot(showStream=True, background=True, show_sf_only=False)
[Figure: kin_plot() output]
In [86]:
#-----------------------#
vgsr_max=-200 #[km/s]
vgsr_min=-400
#-----------------------#

selection_fine.add_mask(name='VGSR',
        mask_func=lambda df: (df['VGSR'] < vgsr_max) & (df['VGSR'] > vgsr_min))
Mask added: 'VGSR'
In [87]:
#-----------------------#
pmra_wiggle = 6 # [mas/yr]
#-----------------------#

selection_fine.add_mask(
    name='PMRA',
    mask_func=lambda df: (
        (
            df['PMRA'] >= np.interp(df['phi1'], [SoI.data.SoI_streamfinder['phi1'].min(), SoI.data.SoI_streamfinder['phi1'].max()],
                                            [SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmin(), 'pmRA'] - pmra_wiggle,
                                             SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmax(), 'pmRA'] - pmra_wiggle])
        ) &
        (
            df['PMRA'] <= np.interp(df['phi1'], [SoI.data.SoI_streamfinder['phi1'].min(), SoI.data.SoI_streamfinder['phi1'].max()],
                                            [SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmin(), 'pmRA'] + pmra_wiggle,
                                             SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmax(), 'pmRA'] + pmra_wiggle])
        ) ))
Mask added: 'PMRA'
In [88]:
#-----------------------#
pmdec_wiggle = 6 # [mas/yr]
#-----------------------#


selection_fine.add_mask(
    name='PMDEC',
    mask_func=lambda df: (
        (
            df['PMDEC'] >= np.interp(df['phi1'], [SoI.data.SoI_streamfinder['phi1'].min(), SoI.data.SoI_streamfinder['phi1'].max()],
                                            [SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmin(), 'pmDE'] - pmdec_wiggle,
                                             SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmax(), 'pmDE'] - pmdec_wiggle])
        ) &
        (
            df['PMDEC'] <= np.interp(df['phi1'], [SoI.data.SoI_streamfinder['phi1'].min(), SoI.data.SoI_streamfinder['phi1'].max()],
                                            [SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmin(), 'pmDE'] + pmdec_wiggle,
                                             SoI.data.SoI_streamfinder.loc[SoI.data.SoI_streamfinder['phi1'].idxmax(), 'pmDE'] + pmdec_wiggle])
        ) ))
Mask added: 'PMDEC'
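
The PMRA and PMDEC masks above both draw a straight line in proper motion between the STREAMFINDER stars at the smallest and largest $\phi_1$, then keep everything within the wiggle of that line. The same logic can be factored into one helper, sketched here with numpy only; the function name is hypothetical.

```python
import numpy as np

def linear_corridor(phi1, value, ref_phi1, ref_value, half_width):
    """Keep points within half_width of the line joining the two end reference stars."""
    ref_phi1 = np.asarray(ref_phi1, float)
    ref_value = np.asarray(ref_value, float)
    lo, hi = np.argmin(ref_phi1), np.argmax(ref_phi1)

    # np.interp holds the end values constant outside [phi1_min, phi1_max]
    centre = np.interp(phi1,
                       [ref_phi1[lo], ref_phi1[hi]],
                       [ref_value[lo], ref_value[hi]])
    return np.abs(np.asarray(value, float) - centre) <= half_width

# Reference stars: the line runs from (-10, -8) to (10, -4); the middle star is ignored
corridor_demo = linear_corridor(
    phi1=[0.0, 0.0, 20.0], value=[-6.5, 1.0, -9.5],
    ref_phi1=[-10.0, 0.0, 10.0], ref_value=[-8.0, -3.0, -4.0],
    half_width=6.0)
```

Note that the line is pinned to just two stars, so an outlier at either end of the stream would shift the whole corridor; a generous wiggle, as used above, makes this less of a concern.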

Let's look at our cuts!¶

In [89]:
selection_fine.list_masks()
Active masks:
- 3D
- phi2
- feh
- iso
- VGSR
- PMRA
- PMDEC
In [90]:
importlib.reload(st)
# Let's take a look at our trimmed DESI data
mask = selection_fine.get_masks(['3D', 'phi2', 'feh','iso', 'VGSR', 'PMRA', 'PMDEC'])

# Using the new method that replaces the 4-line pattern
# VGSR is now automatically computed for confirmed_sf_not_desi within mask_stream()
trimmed_stream = SoI.mask_stream(mask)

plt_trim = st.StreamPlotter(trimmed_stream)
plt_trim.plot_params['sf_in_desi']['alpha']= 0.9
plt_trim.plot_params['background']['alpha']= 0.9
plt_trim.plot_params['background']['s'] = 10
#-----------------------#
plt_trim.kin_plot(showStream=True, show_sf_only=False, background=True) # showStream=True overlays the STREAMFINDER stars in this space.
                                                                       # If there aren't many SF stars in DESI, set show_sf_only=True to show the SF stars not in DESI as a guide
#-----------------------#
...'3D' selected 69353 stars
...'phi2' selected 423352 stars
...'feh' selected 3085218 stars
...'iso' selected 1228176 stars
...'VGSR' selected 46599 stars
...'PMRA' selected 1811662 stars
...'PMDEC' selected 2262018 stars
Selection for specified masks: 154 / 4075716 stars.
Created cut_confirmed_sf_and_desi with 35 stars that were filtered out
Number of stars in SF: 145, Number of DESI and SF stars: 14
Saved merged DataFrame as self.data.confirmed_sf_and_desi_b
Stars only in SF3: 131
[Figure: kin_plot() output after all cuts]

Let's look at an example of a stream (Sylgr-I21) being spotted once our initial cuts have been made. We can see lines of overdensities in all three kinematic dimensions!

Furthermore, we can overlay the STREAMFINDER stars to see that this overdensity lies on the previously discovered stream, and which SF stars we have cut out.

[Figure: two example images of the Sylgr-I21 stream]

These cuts are quite restrictive; we'll widen them now to have more stars to run our mixture model on.

In [91]:
import copy
selection_mcmc = copy.copy(selection_fine)
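
A note on copy.copy: it makes a shallow copy, so any mutable attribute (for example, a dict of masks) is shared between the original and the copy. Whether that matters here depends on how st.Selection stores its masks, which is not shown. The sketch below uses a hypothetical Registry class to show the difference from copy.deepcopy.

```python
import copy

class Registry:
    """Hypothetical container, assuming masks live in a dict attribute."""
    def __init__(self):
        self.masks = {}

fine = Registry()
fine.masks['3D'] = 'mask-a'

shallow = copy.copy(fine)       # new object, but the same inner dict
deep = copy.deepcopy(fine)      # new object with its own inner dict

fine.masks['feh_wide'] = 'mask-b'   # added to the original after copying
```

Here the new entry shows up in shallow but not in deep. A deep copy of a selection would also duplicate any data it holds, which could be expensive with millions of rows, so a shallow copy is a reasonable choice as long as the sharing is intended.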
In [92]:
#-----------------------#
colour_wiggle = 0.25 # [mag]

pmdec_max = 0 # [mas/yr]
pmdec_min = -15

vgsr_max=  -100 # [km/s]
vgsr_min= -400

pmra_max = 0 # [mas/yr]
pmra_min = -15

feh_wide = 0
#-----------------------#

selection_mcmc.add_mask(name='feh_wide',
        mask_func=lambda df: (df['FEH'] < feh_wide))

selection_mcmc.add_mask(name='iso_wide',
        mask_func=lambda df: (stream_funcs.betw(SoI.data.desi_colour_idx, SoI.isochrone_fit(SoI.data.desi_abs_mag), colour_wiggle)))

selection_mcmc.add_mask(name='VGSR_wide',
        mask_func=lambda df: (df['VGSR'] < vgsr_max) & (df['VGSR'] > vgsr_min))

selection_mcmc.add_mask(name='PMRA_wide',
        mask_func=lambda df: (df['PMRA'] < pmra_max) & (df['PMRA'] > pmra_min))

selection_mcmc.add_mask(name='PMDEC_wide',
        mask_func=lambda df: (df['PMDEC'] < pmdec_max) & (df['PMDEC'] > pmdec_min))
Mask added: 'feh_wide'
Mask added: 'iso_wide'
Mask added: 'VGSR_wide'
Mask added: 'PMRA_wide'
Mask added: 'PMDEC_wide'
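
One thing to watch for: this cell reassigns vgsr_max, vgsr_min and colour_wiggle, which the earlier 'VGSR' and 'iso' lambdas also refer to. Python lambdas look up such names when they are called, not when they are defined. If the selection evaluates its mask functions lazily (an assumption; the st.Selection source is not shown), the narrow masks would silently pick up the wide values. A minimal demonstration, with the default-argument idiom as one way to freeze a value:

```python
vgsr_max = -200

# Looks up vgsr_max at call time (late binding)
narrow = lambda v: v < vgsr_max

# Default argument is evaluated once, at definition time
narrow_frozen = lambda v, vmax=vgsr_max: v < vmax

vgsr_max = -100   # reassigned later for the wide cut

late = narrow(-150)           # compares against -100
frozen = narrow_frozen(-150)  # still compares against -200
```

Using distinct names for the wide limits (for example vgsr_max_wide) avoids the question entirely.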
In [93]:
mcmc_mask = selection_mcmc.get_masks(['3D', 'phi2', 'feh_wide','iso_wide', 'VGSR_wide', 'PMRA_wide', 'PMDEC_wide'])
...'3D' selected 69353 stars
...'phi2' selected 423352 stars
...'feh_wide' selected 3085218 stars
...'iso_wide' selected 1586023 stars
...'VGSR_wide' selected 655452 stars
...'PMRA_wide' selected 2450141 stars
...'PMDEC_wide' selected 2897913 stars
Selection for specified masks: 1764 / 4075716 stars.
In [94]:
importlib.reload(st)
tomcmc_stream = SoI.mask_stream(mcmc_mask)

plt_trim = st.StreamPlotter(tomcmc_stream)
plt_trim.plot_params['sf_in_desi']['alpha']= 0.9
plt_trim.plot_params['background']['alpha']= 0.4
plt_trim.plot_params['background']['s'] = 1
#-----------------------#
plt_trim.sixD_plot(showStream=True, show_sf_only=False, background=True) # showStream=True overlays the STREAMFINDER stars in this space.
                                                                       # If there aren't many SF stars in DESI, set show_sf_only=True to show the SF stars not in DESI as a guide
#-----------------------#
Created cut_confirmed_sf_and_desi with 30 stars that were filtered out
Number of stars in SF: 145, Number of DESI and SF stars: 19
Saved merged DataFrame as self.data.confirmed_sf_and_desi_b
Stars only in SF3: 126
Out[94]:
(<Figure size 1000x1500 with 5 Axes>,
 array([<Axes: ylabel='$\\phi_2$'>, <Axes: ylabel='V$_{GSR}$ (km/s)'>,
        <Axes: ylabel='$\\mu_{\\alpha}$ [mas/yr]'>,
        <Axes: ylabel='$\\mu_{\\delta}$ [mas/yr]'>,
        <Axes: xlabel='$\\phi_1$', ylabel='[Fe/H]'>], dtype=object))
[Figure: sixD_plot() output]

Optimizing¶

If you see a stream above, you may continue with the rest of the notebook!

Before jumping into MCMC, we'll find a good starting point using scipy.optimize.minimize.

A stream's velocity and proper motions change with $\phi_1$, so we let the fitted means vary along the track using splines.

#-----------------------#
no_of_spline_points = 5
#-----------------------#

Typically we want between 3 and 5 spline points.
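
To illustrate what the spline points do, the sketch below places five knots along $\phi_1$ and interpolates a VGSR track through them. The knot positions and the spline order k=2 are assumptions made for illustration (the tutorial reads both from MCMeta), and the knot values are rounded from the initial guesses printed further down.

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Hypothetical knot positions along the stream, in degrees
phi1_knots = np.linspace(-10.0, 10.0, 5)

# One free parameter per knot: the mean VGSR at that phi1 [km/s]
vgsr_knots = np.array([-266.0, -225.0, -207.0, -213.0, -242.0])

# Interpolating spline through the knots (order k is an assumption here)
track = InterpolatedUnivariateSpline(phi1_knots, vgsr_knots, k=2)

# Model mean velocity at each star's phi1
phi1_stars = np.array([-10.0, -2.5, 0.0, 7.5])
vgsr_model = track(phi1_stars)
```

Each extra knot adds one free parameter per kinematic dimension, which is why a handful of points is usually enough: more flexibility than that tends to fit noise rather than the stream.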

In [116]:
importlib.reload(st)
#-----------------------#
no_of_spline_points = 5
#-----------------------#

# Define truncation parameters based on our selection cuts
truncation_params = {
    'vgsr_min': vgsr_min, 'vgsr_max': vgsr_max,
    'feh_min': -4.0, 'feh_max': feh_wide,
    'pmra_min': pmra_min, 'pmra_max': pmra_max,
    'pmdec_min': pmdec_min, 'pmdec_max': pmdec_max
}

MCMeta = st.MCMeta(no_of_spline_points, tomcmc_stream, trimmed_stream.data.confirmed_sf_and_desi, truncation_params=truncation_params)
Making stream initial guess based on galstream and STREAMFINDER...
Stream VGSR dispersion from trimmed SF: 5.08 km/s
Stream mean metallicity from trimmed SF: -1.15 +- 0.145 dex
Stream PMRA dispersion from trimmed SF: 0.23 mas/yr
Stream PMDEC dispersion from trimmed SF: 0.52 mas/yr
Making background initial guess...
Background velocity: -151.44 +- 45.62 km/s
Background metallicity: -1.50 +- 0.456 dex
Background PMRA: -3.10 +- 2.96 mas/yr
Background PMDEC: -6.63 +- 3.32 mas/yr

We're doing a mixture model, but what are we mixing? Let's take a look at our initial guesses below:

In [117]:
plt_mcmeta = st.StreamPlotter(MCMeta)

# Show our initial parameter guesses
print("Initial Parameters:")
for param, value in MCMeta.initial_params.items():
    if isinstance(value, np.ndarray):
        print(f"{param}: {value}")
    else:
        print(f"{param}: {value:.4f}")

# Plot the Gaussian mixture model based on our initial guesses
plt_mcmeta.gaussian_mixture_plot(showStream=True, background=True)
Initial Parameters:
lsigvgsr: 0.7062
vgsr_spline_points: [-266.03920172 -224.62881156 -206.89923991 -212.85048676 -242.48255212]
feh1: -1.1451
lsigfeh: -0.8390
lsigpmra: -0.6476
pmra_spline_points: [-7.32072666 -8.09759399 -8.28275927 -7.8762225  -6.87798368]
lsigpmdec: -0.2882
pmdec_spline_points: [-10.31928752  -9.31972844  -8.08484259  -6.61462996  -4.90909055]
bv: -151.4366
lsigbv: 1.6592
bfeh: -1.5036
lsigbfeh: -0.3413
bpmra: -3.0965
lsigbpmra: 0.4718
bpmdec: -6.6303
lsigbpmdec: 0.5213
Out[117]:
(<Figure size 900x900 with 4 Axes>,
 array([[<Axes: xlabel='V$_{GSR}$ (km/s)'>, <Axes: xlabel='[Fe/H]'>],
        [<Axes: xlabel='$\\mu_{RA}$ (mas/yr)'>,
         <Axes: xlabel='$\\mu_{DEC}$ (mas/yr)'>]], dtype=object))
[Figure: gaussian_mixture_plot() output for the initial guesses]
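
To answer "what are we mixing?" in code: each star is modelled as coming from either a narrow stream Gaussian (with probability pstream) or a broad background Gaussian. The sketch below shows this in one dimension (VGSR) only. It is an illustration, not stream_funcs.spline_lnprob_1D, whose source is not shown; in particular it ignores the truncation limits and the other three dimensions. The log10 parametrization of the dispersions matches the printed values (10**0.7062 is about 5.08 km/s).

```python
import numpy as np

def normal_pdf(x, mu, sigma):
    """Gaussian probability density."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

def mixture_lnlike(v, v_err, pstream, mu_s, lsig_s, mu_b, lsig_b):
    """Log-likelihood of a two-component (stream + background) Gaussian mixture."""
    # Intrinsic dispersion (stored as log10) plus measurement error, in quadrature
    sig_s = np.sqrt((10.0 ** lsig_s) ** 2 + v_err ** 2)
    sig_b = np.sqrt((10.0 ** lsig_b) ** 2 + v_err ** 2)

    like = (pstream * normal_pdf(v, mu_s, sig_s)
            + (1.0 - pstream) * normal_pdf(v, mu_b, sig_b))
    return np.sum(np.log(like))

# Toy data: two stream-like stars and two background-like stars [km/s]
v = np.array([-210.0, -205.0, -150.0, -90.0])
v_err = np.full(4, 2.0)

lnl_good = mixture_lnlike(v, v_err, 0.5, -207.0, 0.7, -151.0, 1.66)
lnl_bad = mixture_lnlike(v, v_err, 0.5, -300.0, 0.7, -151.0, 1.66)
```

Placing the stream component near the clump of stars gives a higher log-likelihood than placing it far away, which is what the optimizer and the MCMC sampler exploit.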

Prior¶

In [118]:
vgsr_range_wiggle = 50
pmra_range_wiggle = 5
pmdec_range_wiggle = 5
lsigvgsr_r = (-4, 4)

vgsr_ranges = [(v - vgsr_range_wiggle, v + vgsr_range_wiggle) for v in MCMeta.initial_params['vgsr_spline_points']]
pmra_ranges = [(v - pmra_range_wiggle, v + pmra_range_wiggle) for v in MCMeta.initial_params['pmra_spline_points']]
pmdec_ranges = [(v - pmdec_range_wiggle, v + pmdec_range_wiggle) for v in MCMeta.initial_params['pmdec_spline_points']]

#-----------------------#
feh1_r = (-2,0)
#-----------------------#
lsigfeh_r = (-2,4)
lsigpmra_r = (-5, 1)
lsigpmdec_r = (-5, 1)
bv_r = (-400, 400)
lsigbv_r = (-2,4)
bfeh_r = (-4, 4)
lsigbfeh_r = (-2,2)
bpmra_r = (-20, 20)
lsigbpmra_r = (-2,3)
bpmdec_r = (-20, 10)
lsigbpmdec_r = (-2,3)

pstream_r = (0.0, 1.0)

prior = [
    pstream_r,              # pstream (single value)
    *vgsr_ranges,           # vgsr_spline_points (5 values)
    lsigvgsr_r,             # lsigvgsr (single value)
    feh1_r, lsigfeh_r,      # feh parameters
    *pmra_ranges,           # pmra_spline_points (5 values)  
    lsigpmra_r,             # lsigpmra (single value)
    *pmdec_ranges,          # pmdec_spline_points (5 values)
    lsigpmdec_r,            # lsigpmdec (single value)
    bv_r, lsigbv_r, bfeh_r, lsigbfeh_r,  # background parameters
    bpmra_r, lsigbpmra_r, bpmdec_r, lsigbpmdec_r
]
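
Each entry of prior is a (low, high) pair, one per flattened parameter. A flat "box" prior built from such pairs could look like the sketch below. The large negative return value is an assumption, chosen to be consistent with the `> -9e9` validity check used later in the notebook; the actual prior handling inside stream_funcs is not shown.

```python
import numpy as np

def ln_flat_prior(theta, ranges, bad_value=-1e10):
    """Return 0 if every parameter lies strictly inside its (low, high) range."""
    theta = np.asarray(theta, float)
    lo, hi = np.asarray(ranges, float).T
    inside = np.all((theta > lo) & (theta < hi))
    return 0.0 if inside else bad_value

# Three toy parameters: a fraction, a log-dispersion, and a mean [Fe/H]
demo_ranges = [(0.0, 1.0), (-4, 4), (-2, 0)]
lp_ok = ln_flat_prior([0.1, 0.7, -1.1], demo_ranges)
lp_out = ln_flat_prior([1.2, 0.7, -1.1], demo_ranges)
```

With a box prior, a single parameter outside its range rejects the whole sample, so ranges that are too tight around the initial guess can prevent walkers from starting at all.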
In [119]:
importlib.reload(stream_funcs)

phi1_spline_points = MCMeta.phi1_spline_points  # x-coordinates for splines

spline_k = MCMeta.spline_k

# Create initial parameter array - pstream and lsigvgsr are now single constant values
p0_guess = [
    0.1,                                                 # pstream (constant stream fraction)
    MCMeta.initial_params['vgsr_spline_points'],         # VGSR spline points
    MCMeta.initial_params['lsigvgsr'],                   # lsigvgsr (constant log velocity dispersion)
    MCMeta.initial_params['feh1'],                       # mean [Fe/H]
    MCMeta.initial_params['lsigfeh'],                    # log(sigma_[Fe/H])
    MCMeta.initial_params['pmra_spline_points'],         # PMRA spline points
    MCMeta.initial_params['lsigpmra'],                   # log(sigma_pmra)
    MCMeta.initial_params['pmdec_spline_points'],        # PMDEC spline points
    MCMeta.initial_params['lsigpmdec'],                  # log(sigma_pmdec)
    MCMeta.initial_params['bv'],                         # background VGSR
    MCMeta.initial_params['lsigbv'],                     # log(sigma_background_vgsr)
    MCMeta.initial_params['bfeh'],                       # background [Fe/H]
    MCMeta.initial_params['lsigbfeh'],                   # log(sigma_background_feh)
    MCMeta.initial_params['bpmra'],                      # background PMRA
    MCMeta.initial_params['lsigbpmra'],                  # log(sigma_background_pmra)
    MCMeta.initial_params['bpmdec'],                     # background PMDEC
    MCMeta.initial_params['lsigbpmdec']                  # log(sigma_background_pmdec)
]

# Updated parameter labels - pstream and lsigvgsr are now single parameters
param_labels = ['pstream', 'vgsr_spline_points', 'lsigvgsr', 'feh1', 'lsigfeh', 
                'pmra_spline_points', 'lsigpmra', 'pmdec_spline_points', 'lsigpmdec',
                'bv', 'lsigbv', 'bfeh', 'lsigbfeh', 'bpmra', 'lsigbpmra', 'bpmdec', 'lsigbpmdec']

vgsr_trunc = [MCMeta.truncation_params['vgsr_min'], MCMeta.truncation_params['vgsr_max']]
feh_trunc = [MCMeta.truncation_params['feh_min'], MCMeta.truncation_params['feh_max']]  
pmra_trunc = [MCMeta.truncation_params['pmra_min'], MCMeta.truncation_params['pmra_max']]
pmdec_trunc = [MCMeta.truncation_params['pmdec_min'], MCMeta.truncation_params['pmdec_max']]

array_lengths = [len(x) if isinstance(x, np.ndarray) else 1 for x in p0_guess]
flat_p0_guess = np.hstack(p0_guess)


# Fixed optimization function with correct signature
optfunc = lambda theta: -stream_funcs.spline_lnprob_1D(
    theta, prior, phi1_spline_points,  # Only phi1_spline_points needed
    tomcmc_stream.data.desi_data['VGSR'], tomcmc_stream.data.desi_data['VRAD_ERR'],
    tomcmc_stream.data.desi_data['FEH'], tomcmc_stream.data.desi_data['FEH_ERR'],
    tomcmc_stream.data.desi_data['PMRA'], tomcmc_stream.data.desi_data['PMRA_ERROR'],
    tomcmc_stream.data.desi_data['PMDEC'], tomcmc_stream.data.desi_data['PMDEC_ERROR'],
    tomcmc_stream.data.desi_data['phi1'], 
    trunc_fit=True, feh_fit=True, assert_prior=False, k=spline_k, 
    reshape_arr_shape=array_lengths,
    vgsr_trunc=vgsr_trunc, feh_trunc=feh_trunc, 
    pmra_trunc=pmra_trunc, pmdec_trunc=pmdec_trunc
)

# Run optimization
print("Running optimization...")
%time result = optimize.minimize(optfunc, flat_p0_guess, method="Nelder-Mead")
print(result.message)

# Reshape and process results
reshaped_result = stream_funcs.reshape_arr(result.x, array_lengths)
output = stream_funcs.get_paramdict(reshaped_result, labels=param_labels)

print("\nOptimized Parameters:")
for label, value in output.items():
    if label.startswith('l'):
        print(f"{label[1:]}: {10**value:.4f}")
    else:
        print(f"{label}: {value:.4f}" if isinstance(value, (int, float)) else f"{label}: {value}")
Running optimization...
CPU times: user 7.04 s, sys: 20.5 ms, total: 7.06 s
Wall time: 7.08 s
Maximum number of function evaluations has been exceeded.

Optimized Parameters:
pstream: 0.0341
vgsr_spline_points: [-223.72867231 -192.52493592 -197.42651142 -213.50434964 -224.0153508 ]
sigvgsr: 0.1473
feh1: -1.1832
sigfeh: 0.0100
pmra_spline_points: [-12.22746583  -8.38062923  -8.50946684  -7.87010161 -10.76854696]
sigpmra: 0.0339
pmdec_spline_points: [-15.31897716  -9.89349392  -8.06340032  -6.75020055  -7.13431764]
sigpmdec: 0.1672
bv: 152.6786
sigbv: 131.9339
bfeh: -1.4982
sigbfeh: 0.4161
bpmra: 3.3405
sigbpmra: 5.3268
bpmdec: -6.1635
sigbpmdec: 3.9881
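
The message "Maximum number of function evaluations has been exceeded." means Nelder-Mead stopped at its evaluation budget rather than at convergence, so the values above are a starting point and not a converged fit. SciPy's default budget for this method is 200 times the number of parameters; it can be raised through the options argument. The sketch below shows the effect on a standard two-parameter test function, not on the stream likelihood.

```python
import numpy as np
from scipy import optimize

x0 = np.array([-1.2, 1.0])

# Deliberately tiny budget: the optimizer gives up early
short = optimize.minimize(optimize.rosen, x0, method="Nelder-Mead",
                          options={"maxfev": 20})

# Generous budget: the optimizer can run to convergence
longer = optimize.minimize(optimize.rosen, x0, method="Nelder-Mead",
                           options={"maxfev": 5000, "maxiter": 5000})

print(short.success, short.message)
print(longer.success, longer.x)
```

Checking result.success before reusing result.x is a cheap safeguard. For the MCMC step an unconverged optimum can still serve as a starting point, since the burn-in gives the walkers room to move.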
In [120]:
# Add optimization results to MCMeta class
MCMeta.optimized_params = output
MCMeta.optimization_result = result

# Create a new plotter object with the optimized parameters
plt_optimized = st.StreamPlotter(MCMeta)


original_params = MCMeta.initial_params.copy()
optimized_for_plotting = output.copy()



MCMeta.initial_params = optimized_for_plotting

print("\nPlotting optimized Gaussian mixture...")
plt_optimized.gaussian_mixture_plot(showStream=True, background=True)

# Restore original initial_params
MCMeta.initial_params = original_params

plt.suptitle('Optimized Gaussian Mixture Model', fontsize=14, y=0.98)
plt.tight_layout()
plt.show()

# Below checks if we picked good priors


nparams = len(param_labels)
nwalkers = 70

p0 = flat_p0_guess 
ep0 = np.zeros(len(p0)) + 0.01

# Generate walker positions around the starting point
p0s = np.random.multivariate_normal(p0, np.diag(ep0)**2, size=nwalkers)

# Clip pstream to valid range [0, 1] - first parameter only
p0s[:,0] = np.clip(p0s[:,0], 1e-10, 1 - 1e-10)

# Test likelihood for all walkers using the modified function
lkhds = [stream_funcs.spline_lnprob_1D(
    p0s[j], prior, phi1_spline_points, 
    tomcmc_stream.data.desi_data['VGSR'], tomcmc_stream.data.desi_data['VRAD_ERR'],
    tomcmc_stream.data.desi_data['FEH'], tomcmc_stream.data.desi_data['FEH_ERR'],
    tomcmc_stream.data.desi_data['PMRA'], tomcmc_stream.data.desi_data['PMRA_ERROR'],
    tomcmc_stream.data.desi_data['PMDEC'], tomcmc_stream.data.desi_data['PMDEC_ERROR'],
    tomcmc_stream.data.desi_data['phi1'], 
    trunc_fit=True, feh_fit=True, assert_prior=True, k=spline_k, 
    reshape_arr_shape=array_lengths,
    vgsr_trunc=vgsr_trunc, feh_trunc=feh_trunc, 
    pmra_trunc=pmra_trunc, pmdec_trunc=pmdec_trunc
) for j in range(nwalkers)]

# Check that the prior admits every walker: values at or below -9e9 are treated as rejected
n_valid = int(np.sum(np.array(lkhds) > -9e9))
if n_valid == nwalkers:
    print('Your prior is good, you\'ve found something!')
else:
    print('Your prior is too restrictive, try changing the values listed above!')

# Assert that all walkers have good likelihoods
assert n_valid == nwalkers, f"Only {n_valid}/{nwalkers} walkers have valid likelihoods"

print(f"All {nwalkers} walkers initialized successfully!")
Plotting optimized Gaussian mixture...
Your prior is good, you've found something!
All 70 walkers initialized successfully!
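The walker setup above follows a common pattern: scatter the walkers in a small Gaussian ball around a starting point, then pull any that land outside the prior box back inside. Here is a minimal, self-contained sketch of that pattern. The function name `init_walkers`, the three example parameters, and their bounds are made up for illustration and are not part of `stream_functions`.

```python
import numpy as np

def init_walkers(center, bounds, nwalkers, scale=0.01, seed=None):
    """Scatter walkers around `center` and keep them strictly inside `bounds`."""
    rng = np.random.default_rng(seed)
    center = np.asarray(center, dtype=float)
    bounds = np.asarray(bounds, dtype=float)  # shape (ndim, 2): [min, max] per parameter
    # Independent Gaussian scatter with standard deviation `scale` in every parameter
    walkers = center + scale * rng.standard_normal((nwalkers, center.size))
    # Move anything outside the prior box back inside, just off the boundary
    eps = 1e-10
    return np.clip(walkers, bounds[:, 0] + eps, bounds[:, 1] - eps)

# Hypothetical example: a mixing fraction plus two other parameters
center = [0.02, -200.0, -1.5]
bounds = np.array([(0.0, 1.0), (-400.0, 0.0), (-3.0, 0.0)])
walkers = init_walkers(center, bounds, nwalkers=70, seed=42)

inside = np.all((walkers > bounds[:, 0]) & (walkers < bounds[:, 1]))
print(walkers.shape, inside)
```

Clipping to every prior range, not just the first parameter, is what lets the "is the prior good?" test pass for parameters that start close to a boundary.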
In [121]:
importlib.reload(st)

plt_enhanced = st.StreamPlotter(MCMeta) 
plt_enhanced.plot_params['sf_in_desi']['alpha'] = 0.9
plt_enhanced.plot_params['background']['alpha'] = 0.4
plt_enhanced.plot_params['background']['s'] = 1

fig, ax = plt_enhanced.sixD_plot(
    showStream=True, 
    show_sf_only=True, 
    background=True,
    show_initial_splines=True,
    show_optimized_splines=True,
    show_sf_errors=True  # Now enable error bars
)

plt.tight_layout()
plt.show()

EMCEE

In [122]:
from datetime import datetime
import os
note = 'first_try'
current_date = datetime.now().strftime("%y%m%d")
output_dir = f"./runs/{streamName}_{current_date}_{note}"
os.makedirs(output_dir, exist_ok=True)

#-----------------------#
backend_file = os.path.join(output_dir, f"{streamName}_{phi2_wiggle}deg_{no_of_spline_points}spline_.h5")
backend = emcee.backends.HDFBackend(backend_file)
#-----------------------#
backend.reset(nwalkers,len(p0))
In [137]:
%%time 
from multiprocessing import Pool
importlib.reload(stream_funcs)

#-----------------------#
nproc = 32
nburnin = 5000
nstep = 5000
use_optimized_start = True  # Toggle: True for optimized results, False for initial guess
#-----------------------#

# Choose starting positions based on toggle
if use_optimized_start:
    print("Using optimized parameters as starting positions...")
    start_params = result.x
    start_label = "optimized"
else:
    print("Using initial guess as starting positions...")
    start_params = flat_p0_guess
    start_label = "initial_guess"

print("Validating initial walker positions with the log-probability function...")
# Note: p0s here is whatever was drawn before this cell ran (e.g. in the prior-check cell above);
# the walkers used for sampling are redrawn around start_params further down.
# These arguments must match what is passed to EnsembleSampler below.
args = (prior, phi1_spline_points,
        tomcmc_stream.data.desi_data['VGSR'].values, tomcmc_stream.data.desi_data['VRAD_ERR'].values,
        tomcmc_stream.data.desi_data['FEH'].values, tomcmc_stream.data.desi_data['FEH_ERR'].values,
        tomcmc_stream.data.desi_data['PMRA'].values, tomcmc_stream.data.desi_data['PMRA_ERROR'].values,
        tomcmc_stream.data.desi_data['PMDEC'].values, tomcmc_stream.data.desi_data['PMDEC_ERROR'].values,
        tomcmc_stream.data.desi_data['phi1'].values,
        True, False, True, spline_k, array_lengths,
        vgsr_trunc, feh_trunc, pmra_trunc, pmdec_trunc)

non_finite_walkers = []
for i, p in enumerate(p0s):
    log_prob = stream_funcs.spline_lnprob_1D(p, *args)
    if not np.isfinite(log_prob):
        print(f"Walker {i} has a non-finite log-probability: {log_prob}")
        non_finite_walkers.append((i, p, log_prob))

if non_finite_walkers:
    # You can add more detailed debugging here, for example, printing the problematic parameters
    raise ValueError(f"{len(non_finite_walkers)} walkers have non-finite log-probabilities. Halting execution.")
else:
    print("All initial walker positions have finite log-probabilities. Proceeding to MCMC.")


with Pool(nproc) as pool:
    print(f'Running burnin with {nburnin} iterations starting from {start_label} parameters')
    p0 = start_params 
    ep0 = np.zeros(len(p0)) + 0.01
    assert np.all(np.isfinite(start_params)), "start_params contains NaN or inf"
    # Generate walker positions around the starting point
    p0s = np.random.multivariate_normal(p0, np.diag(ep0)**2, size=nwalkers)

    print("Clipping all walker positions to be within prior ranges...")
    for i in range(len(prior)):
        min_val, max_val = prior[i]
        # Add a small buffer to avoid being exactly on the boundary
        buffer = 1e-10
        p0s[:, i] = np.clip(p0s[:, i], min_val + buffer, max_val - buffer)

    # Special clipping for pstream to [0, 1] if it's the first parameter
    p0s[:,0] = np.clip(p0s[:,0], 1e-10, 1 - 1e-10)
        
    start = time.time()
    es = emcee.EnsembleSampler(
        nwalkers, len(flat_p0_guess), stream_funcs.spline_lnprob_1D,
        args=(prior, phi1_spline_points, 
              tomcmc_stream.data.desi_data['VGSR'].values, tomcmc_stream.data.desi_data['VRAD_ERR'].values,
              tomcmc_stream.data.desi_data['FEH'].values, tomcmc_stream.data.desi_data['FEH_ERR'].values,
              tomcmc_stream.data.desi_data['PMRA'].values, tomcmc_stream.data.desi_data['PMRA_ERROR'].values,
              tomcmc_stream.data.desi_data['PMDEC'].values, tomcmc_stream.data.desi_data['PMDEC_ERROR'].values,
              tomcmc_stream.data.desi_data['phi1'].values, 
              True, False, True, spline_k, array_lengths,
              vgsr_trunc, feh_trunc, pmra_trunc, pmdec_trunc),
        pool=pool, backend=backend)
    PP = es.run_mcmc(p0s, nburnin)
    print(f'Took {(time.time()-start):.1f} seconds ({(time.time()-start)/60:.1f} minutes)')
    
    print(f'Now sampling with {nstep} iterations')
    es.reset()
    start = time.time()
    es.run_mcmc(PP.coords, nstep)
    print(f'Took {(time.time()-start):.1f} seconds ({(time.time()-start)/60:.1f} minutes)')
    
    chain = es.chain
Using optimized parameters as starting positions...
Validating initial walker positions with the log-probability function...
All initial walker positions have finite log-probabilities. Proceeding to MCMC.
Running burnin with 5000 iterations starting from optimized parameters
Clipping all walker positions to be within prior ranges...
Took 100.7 seconds (1.7 minutes)
Now sampling with 5000 iterations
Took 86.7 seconds (1.4 minutes)
CPU times: user 1min 49s, sys: 1min 5s, total: 2min 54s
Wall time: 3min 7s
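Before reading anything off the samples, it's worth checking that the chains have actually converged. One rough check is the Gelman-Rubin statistic, which compares the scatter between walkers to the scatter within each walker. The sketch below is only an illustration on synthetic chains: the helper `gelman_rubin` is not part of `stream_functions`, and walkers in an ensemble sampler are not independent, so treat the number as a sanity check rather than a guarantee. emcee's own `get_autocorr_time` is the more standard diagnostic.

```python
import numpy as np

def gelman_rubin(chain):
    """R-hat per parameter for a chain of shape (nwalkers, nsteps, ndim)."""
    nwalkers, nsteps, ndim = chain.shape
    walker_means = chain.mean(axis=1)                 # (nwalkers, ndim)
    walker_vars = chain.var(axis=1, ddof=1)           # (nwalkers, ndim)
    W = walker_vars.mean(axis=0)                      # mean within-walker variance
    B = nsteps * walker_means.var(axis=0, ddof=1)     # between-walker variance
    var_hat = (nsteps - 1) / nsteps * W + B / nsteps  # pooled variance estimate
    return np.sqrt(var_hat / W)

# Synthetic chains: one well mixed, one where each walker sits at its own offset
rng = np.random.default_rng(0)
mixed = rng.normal(size=(20, 2000, 2))
stuck = mixed + np.linspace(-3, 3, 20)[:, None, None]

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(rhat_mixed)
print(rhat_stuck)
```

Values close to 1 mean the walkers agree with each other; values well above 1 mean they are still exploring different regions, and a longer burn-in is probably needed.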
In [138]:
# Create expanded parameter labels to match the flattened chain dimensions
indices = np.arange(1, no_of_spline_points + 1).astype(str)
velocity_labels = ['vgsr' + i for i in indices]
pmra_labels = ['pmra' + i for i in indices] 
pmdec_labels = ['pmdec' + i for i in indices]

expanded_param_labels = (['pstream'] + 
                        velocity_labels + 
                        ['lsigvgsr', 'feh1', 'lsigfeh'] +
                        pmra_labels + 
                        ['lsigpmra'] +
                        pmdec_labels +
                        ['lsigpmdec', 'bv', 'lsigbv', 'bfeh', 'lsigbfeh', 'bpmra', 'lsigbpmra', 'bpmdec', 'lsigbpmdec'])

Nrow = chain.shape[2]
fig, axes = plt.subplots(Nrow, figsize=(6,2*Nrow))

prior = np.array(prior)

for iparam,ax in enumerate(axes):
    for j in range(nwalkers):
        ax.plot(chain[j,:,iparam], lw=.5, alpha=.2)
    # Label each panel once, not once per walker
    ax.set_ylabel(expanded_param_labels[iparam])
fig.tight_layout()
In [139]:
flatchain = es.flatchain
flatchain.shape
fig = corner.corner(flatchain, labels=expanded_param_labels, quantiles=[0.16,0.50,0.84], show_titles=True)
In [140]:
meds, errs = stream_funcs.process_chain(flatchain, labels = expanded_param_labels)
print(len(meds))
print(meds)
exp_flatchain = np.copy(flatchain)
# Labels starting with 'l' are log10 quantities; convert those columns to linear units
for i, label in enumerate(meds.keys()):
    if label.startswith('l'):
        exp_flatchain[:,i]= 10 ** exp_flatchain[:,i]
exp_meds, exp_errs = stream_funcs.process_chain(exp_flatchain, labels = expanded_param_labels)
29
OrderedDict([('pstream', 0.016113677693207894), ('vgsr1', -277.34170393351195), ('vgsr2', -194.19233033700183), ('vgsr3', -197.42109258005962), ('vgsr4', -213.11724981821308), ('vgsr5', -226.0680732079289), ('lsigvgsr', -2.0118703659336354), ('feh1', -1.1772622839426974), ('lsigfeh', -0.9637764576119371), ('pmra1', -7.050739962662809), ('pmra2', -7.80538115473079), ('pmra3', -8.457955107720883), ('pmra4', -7.819741456933046), ('pmra5', -8.359696572916443), ('lsigpmra', -0.7824906806404727), ('pmdec1', -8.016673511885593), ('pmdec2', -8.808982819772375), ('pmdec3', -8.100639837588277), ('pmdec4', -6.591421140316544), ('pmdec5', -5.757658968457364), ('lsigpmdec', -1.5141550084186763), ('bv', 198.79777167149), ('lsigbv', 2.1485754850279326), ('bfeh', -1.48646972866319), ('lsigbfeh', -0.3727593478273431), ('bpmra', 19.078418382510815), ('lsigbpmra', 0.9409904484832435), ('bpmdec', -6.2922301214472185), ('lsigbpmdec', 0.5985347192022115)])
In [141]:
_, ep, em = stream_funcs.process_chain(flatchain, avg_error=False, labels = expanded_param_labels)
print(len(meds))
print(meds)
exp_flatchain = np.copy(flatchain)
for i, label in enumerate(meds.keys()):
    if label[0] == 'l':
        exp_flatchain[:,i]= 10 ** exp_flatchain[:,i]
_, exp_ep, exp_em = stream_funcs.process_chain(exp_flatchain, avg_error=False, labels = expanded_param_labels)
29
OrderedDict([('pstream', 0.016113677693207894), ('vgsr1', -277.34170393351195), ('vgsr2', -194.19233033700183), ('vgsr3', -197.42109258005962), ('vgsr4', -213.11724981821308), ('vgsr5', -226.0680732079289), ('lsigvgsr', -2.0118703659336354), ('feh1', -1.1772622839426974), ('lsigfeh', -0.9637764576119371), ('pmra1', -7.050739962662809), ('pmra2', -7.80538115473079), ('pmra3', -8.457955107720883), ('pmra4', -7.819741456933046), ('pmra5', -8.359696572916443), ('lsigpmra', -0.7824906806404727), ('pmdec1', -8.016673511885593), ('pmdec2', -8.808982819772375), ('pmdec3', -8.100639837588277), ('pmdec4', -6.591421140316544), ('pmdec5', -5.757658968457364), ('lsigpmdec', -1.5141550084186763), ('bv', 198.79777167149), ('lsigbv', 2.1485754850279326), ('bfeh', -1.48646972866319), ('lsigbfeh', -0.3727593478273431), ('bpmra', 19.078418382510815), ('lsigbpmra', 0.9409904484832435), ('bpmdec', -6.2922301214472185), ('lsigbpmdec', 0.5985347192022115)])
In [142]:
# Columns labelled exp(...) hold the 10**value results for the log10 parameters (labels starting with 'l')
print("{:<10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format('param','med', 'em','ep','exp(med)', 'exp(em)','exp(ep)'))
print('--------------------------------------------------------------------------------------')
for label,v in meds.items():
    if label[0] == 'l':
        print("{:<10} {:>10.3f} {:>10.3f} {:>10.3f} {:>10.3f} {:>10.3f} {:>10.3f}".format(label,v,em[label],ep[label], exp_meds[label], exp_em[label], exp_ep[label]))
    else:
        print("{:<10} {:>10.3f} {:>10.3f} {:>10.3f}".format(label, v, em[label], ep[label]))
param             med         em         ep   exp(med)    exp(em)    exp(ep)
--------------------------------------------------------------------------------------
pstream         0.016     -0.003      0.003
vgsr1        -277.342     -9.105     44.218
vgsr2        -194.192     -4.025      4.567
vgsr3        -197.421     -0.610      0.575
vgsr4        -213.117     -0.488      0.496
vgsr5        -226.068     -2.347      2.255
lsigvgsr       -2.012     -1.327      1.534      0.010     -0.009      0.323
feh1           -1.177     -0.027      0.028
lsigfeh        -0.964     -0.097      0.091      0.109     -0.022      0.025
pmra1          -7.051     -2.055      2.593
pmra2          -7.805     -0.441      0.441
pmra3          -8.458     -0.071      0.073
pmra4          -7.820     -0.059      0.063
pmra5          -8.360     -0.204      0.206
lsigpmra       -0.782     -0.153      0.113      0.165     -0.049      0.049
pmdec1         -8.017     -6.312      1.553
pmdec2         -8.809     -0.802      0.544
pmdec3         -8.101     -0.049      0.051
pmdec4         -6.591     -0.043      0.039
pmdec5         -5.758     -0.177      0.182
lsigpmdec      -1.514     -1.838      0.638      0.031     -0.030      0.102
bv            198.798    -78.331     67.680
lsigbv          2.149     -0.051      0.037    140.791    -15.499     12.587
bfeh           -1.486     -0.011      0.011
lsigbfeh       -0.373     -0.008      0.008      0.424     -0.007      0.008
bpmra          19.078     -1.527      0.679
lsigbpmra       0.941     -0.015      0.009      8.730     -0.297      0.182
bpmdec         -6.292     -0.112      0.121
lsigbpmdec      0.599     -0.012      0.013      3.968     -0.111      0.116
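The table reports each parameter as a median with a lower (`em`) and upper (`ep`) offset, and repeats this in linear units for the log10 dispersions. The internals of `process_chain` aren't shown here, so the sketch below assumes it uses the 16th, 50th, and 84th percentiles (the same quantiles passed to `corner` above). The helper `summarize` and the synthetic samples are made up for illustration.

```python
import numpy as np

def summarize(samples):
    """Median with (minus, plus) offsets to the 16th and 84th percentiles."""
    lo, med, hi = np.percentile(samples, [16, 50, 84])
    return med, lo - med, hi - med

# Synthetic log10(sigma) samples standing in for one column of the flat chain
rng = np.random.default_rng(1)
lsig = rng.normal(-1.0, 0.1, size=50_000)

med, em, ep = summarize(lsig)                    # summary in log10 space
lin_med, lin_em, lin_ep = summarize(10 ** lsig)  # summary in linear space

print(f"log10:  {med:.3f} {em:+.3f} {ep:+.3f}")
print(f"linear: {lin_med:.3f} {lin_em:+.3f} {lin_ep:+.3f}")
```

Transforming the samples first and summarizing second keeps the median consistent (10 to the power of the log median) while letting the error bars become asymmetric, which is why the `exp(em)` and `exp(ep)` columns differ in size for poorly constrained dispersions.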
In [143]:
# Calculate membership probabilities using stream_funcs.spline_memprob_1D

# Get the data from the stream object that was optimized
data = tomcmc_stream.data.desi_data

# Extract the relevant parameters from the MCMC results
theta_final = list(meds.values())  # Use the median parameters from MCMC

# Calculate membership probabilities
stream_prob = stream_funcs.spline_memprob_1D(
    theta=theta_final,
    spline_x_points=phi1_spline_points,
    pstream_spline_x_points=phi1_spline_points,  # Use same spline points for pstream
    lsig_vgsr_spline_points=phi1_spline_points,  # Use same spline points for lsig_vgsr
    vgsr=data['VGSR'].values,
    vgsr_err=data['VRAD_ERR'].values,
    feh=data['FEH'].values,
    feh_err=data['FEH_ERR'].values,
    pmra=data['PMRA'].values,
    pmra_err=data['PMRA_ERROR'].values,
    pmdec=data['PMDEC'].values,
    pmdec_err=data['PMDEC_ERROR'].values,
    phi1=data['phi1'].values,
    trunc_fit=True,  # Use truncated fitting as in your setup
    reshape_arr_shape=array_lengths,
    k=spline_k,
    vgsr_trunc=vgsr_trunc,
    feh_trunc=feh_trunc,
    pmra_trunc=pmra_trunc,
    pmdec_trunc=pmdec_trunc
)

print(f"Calculated membership probabilities for {len(stream_prob)} stars")
print(f"Membership probabilities range from {np.min(stream_prob):.3f} to {np.max(stream_prob):.3f}")
print(f"Mean membership probability: {np.mean(stream_prob):.3f}")
print(f"Stars with >50% probability: {len(stream_prob[stream_prob > 0.5])}")
print(f"Stars with >70% probability: {len(stream_prob[stream_prob > 0.7])}")
print(f"Stars with >90% probability: {len(stream_prob[stream_prob > 0.9])}")
Calculated membership probabilities for 1764 stars
Membership probabilities range from 0.000 to 1.000
Mean membership probability: 0.015
Stars with >50% probability: 27
Stars with >70% probability: 25
Stars with >90% probability: 24
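To build some intuition for where these numbers come from, here is the standard two-component mixture formula for a membership probability, $p_i = f L_s / (f L_s + (1 - f) L_b)$, written out for a single observable. This is a simplified sketch and not the code inside `spline_memprob_1D`: it assumes plain (untruncated) Gaussians, one dimension, and constant means, whereas the notebook's model multiplies the velocity, metallicity, and proper-motion terms and lets the stream track vary with $\phi_1$. All names and numbers below are made up.

```python
import numpy as np

def gauss(x, mu, sigma):
    """Normal probability density."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def membership_prob(x, x_err, f_stream, mu_s, sig_s, mu_b, sig_b):
    """Probability that each star belongs to the stream component."""
    # Add each star's measurement error in quadrature to the intrinsic dispersions
    like_s = gauss(x, mu_s, np.hypot(sig_s, x_err))
    like_b = gauss(x, mu_b, np.hypot(sig_b, x_err))
    num = f_stream * like_s
    return num / (num + (1 - f_stream) * like_b)

# Four hypothetical stars: two near the stream velocity, two far from it
v = np.array([-200.0, -195.0, 50.0, 300.0])
v_err = np.full_like(v, 2.0)
p = membership_prob(v, v_err, f_stream=0.02,
                    mu_s=-198.0, sig_s=5.0, mu_b=150.0, sig_b=130.0)
print(np.round(p, 3))
```

Even with a stream fraction of only a few percent, a narrow stream component can dominate over a broad background near the stream track, which is why the probabilities above pile up near 0 and 1 with very few stars in between.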
In [144]:
# Create membership probability histogram
fig, ax = plt.subplots(1, 1, figsize=(7, 6))

#-----------------------#
min_prob = 0.5  # Change this to adjust the stream probability threshold
#-----------------------#

ax.hist(stream_prob, bins=50, alpha=0.7, color='blue', 
        label=f'All Stars ({len(stream_prob)} total)', density=True)

ax.axvline(min_prob, color='black', linestyle='--', alpha=0.7, label=f'{min_prob:.0%} threshold')

# Format plot
ax.set_xlabel('Membership Probability', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.set_title(f'Membership Probability Distribution for {streamName}', fontsize=14)
ax.legend()
ax.grid(alpha=0.3)
stream_funcs.plot_form(ax)

# log yaxis
ax.set_yscale('log')

plt.tight_layout()
plt.show()

# Print statistics for different probability thresholds
thresholds = [0.3, 0.5, 0.9]
print("\nMembership statistics:")
print("Threshold | Count")
print("----------|-------")
for thresh in thresholds:
    count = len(stream_prob[stream_prob >= thresh])
    print(f"   ≥{thresh:3.1f}   | {count:5d} ")
Membership statistics:
Threshold | Count
----------|-------
   ≥0.3   |    28 
   ≥0.5   |    27 
   ≥0.9   |    24 
In [146]:
importlib.reload(st)

plt_enhanced = st.StreamPlotter(MCMeta) 
plt_enhanced.plot_params['sf_in_desi']['alpha'] = 0.5
plt_enhanced.plot_params['sf_in_desi']['zorder'] = 10
plt_enhanced.plot_params['background']['alpha'] = 0.4
plt_enhanced.plot_params['background']['s'] = 1

fig, ax = plt_enhanced.sixD_plot(
    showStream=False, 
    show_sf_only=True, 
    background=True,
    show_initial_splines=False,
    show_optimized_splines=True,
    show_sf_errors=False,
    show_mcmc_splines=True,
    show_membership_prob=True, stream_prob=stream_prob, min_prob=min_prob
)

#plt.tight_layout()
plt.show()